'''Trains a simple convnet on the MNIST dataset and embeds test data.

The test data is embedded using the weights of the final dense layer, just
before the classification head. This embedding can then be visualized using
TensorBoard's Embedding Projector.
'''

from __future__ import print_function

from os import makedirs
from os.path import exists, join

import keras
from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

import numpy as np
# Samples per gradient update; reduce if GPU memory is limited
# (powers of two are conventional).
batch_size = 128
# Number of output classes (MNIST digits 0-9).
num_classes = 10
# Number of full passes over the training data.
epochs = 50
# Directory where TensorBoard logs and embedding metadata are written.
log_dir = './logs'

# exist_ok avoids the race-prone check-then-create (LBYL) pattern of
# `if not exists(...): makedirs(...)`.
makedirs(log_dir, exist_ok=True)

# Input image dimensions for MNIST.
img_rows, img_cols = 28, 28

# Load the train/test split (downloaded to ~/.keras/datasets on first
# use; for large datasets it is better to place the file there manually).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Insert the single grayscale channel axis where the backend expects it.
# This matters: channels-first frameworks want (1, H, W), channels-last
# want (H, W, 1), and a mismatch breaks training.
if K.image_data_format() == 'channels_first':
    channel_axis = 1
    input_shape = (1, img_rows, img_cols)
else:
    channel_axis = -1
    input_shape = (img_rows, img_cols, 1)
x_train = np.expand_dims(x_train, axis=channel_axis)
x_test = np.expand_dims(x_test, axis=channel_axis)

# Cast to float and scale pixel values from [0, 255] into [0, 1].
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Save the test-set class labels to disk so TensorBoard's Embedding
# Projector can color the data points by digit. Must run while y_test
# still holds integer class indices (i.e. before to_categorical).
# np.savetxt's default float format ('%.18e') would write labels like
# '7.000000000000000000e+00'; '%d' keeps them as plain integers.
with open(join(log_dir, 'metadata.tsv'), 'w') as f:
    np.savetxt(f, y_test, fmt='%d')

# Convert integer class vectors to one-hot ("binary class matrix")
# targets, as required by categorical_crossentropy.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# TensorBoard callback set up to embed the activations of the layer
# named 'features' over x_test, labeled by metadata.tsv. Deliberately
# disabled; uncomment this together with the `callbacks` argument in
# model.fit() below to record embeddings.
# NOTE(review): the embeddings_* arguments belong to an older Keras
# TensorBoard callback API — confirm against the installed Keras version.
# tensorboard = TensorBoard(batch_size=batch_size,
#                           embeddings_freq=1,
#                           embeddings_layer_names=['features'],
#                           embeddings_metadata='metadata.tsv',
#                           embeddings_data=x_test)

# Build the convnet: two conv layers, max-pooling, dropout, then a
# 128-unit dense embedding layer (named 'features' so the TensorBoard
# callback can find it) feeding the softmax classification head.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu',
           input_shape=input_shape),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu', name='features'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

# Compile with categorical cross-entropy and the Adadelta optimizer,
# tracking accuracy during training.
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

# Train the network. The commented-out callbacks argument wires in the
# TensorBoard callback (see the disabled setup above) for embedding
# visualization.
model.fit(x_train,
          y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          # callbacks=[tensorboard],
          validation_data=(x_test, y_test))

# One final evaluation on the held-out test set, then report the
# loss and accuracy.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# 以下是打开tensorboard的方法
# You can now launch tensorboard with `tensorboard --logdir=./logs` on your
# command line and then go to http://localhost:6006/#projector to view the
# embeddings